from copy import deepcopy
import numpy as np
import torch
from torch.optim import Adam
from numpy import linalg as LA
import gym
import argparse
import json
from utils import redirect_stdout
import torch.nn as nn
import itertools
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from spinup_copy import mpi_avg, mpi_statistics_scalar, num_procs, setup_pytorch_for_mpi, sync_params, mpi_avg_grads
import scipy.signal
from gym.spaces import Box, Discrete
import os, time
import random
from actor_critic import TD3, DDPG, SAC
import matplotlib.pyplot as plt
from gym.wrappers.monitoring.video_recorder import VideoRecorder
import matplotlib.colors as colr
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # Evaluate a pretrained actor-critic agent (TD3 / DDPG / SAC):
    #   1. build the agent for --env with the requested network size,
    #   2. load its saved weights from ../models (or ../adv_models_3 with --adv),
    #   3. roll the policy out for --steps environment steps while recording
    #      frames to ../videos/<name>.mp4,
    #   4. print the undiscounted return of every episode that finishes.
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', type=str, default='td3')
    parser.add_argument('--env', type=str, default='HalfCheetah-v2')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--pfm', type=int, default=6007)
    parser.add_argument('--adv', action='store_true')
    parser.add_argument('--name', type=str, default='temp')
    # Rollout length; the default reproduces the original fixed 1000 steps.
    parser.add_argument('--steps', type=int, default=1000)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Explicit lookup rather than eval() on a command-line string: eval would
    # execute arbitrary code and gives an unhelpful NameError on a typo.
    agent_classes = {'TD3': TD3, 'DDPG': DDPG, 'SAC': SAC}
    agent_cls = agent_classes.get(args.alg.upper())
    if agent_cls is None:
        parser.error('unknown --alg %r; expected one of: %s'
                     % (args.alg, ', '.join(sorted(k.lower() for k in agent_classes))))
    agent = agent_cls(env_name=args.env,
                      ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
                      gamma=args.gamma)

    agent.env.seed(args.seed)

    # DDPG stores a single critic `q`; TD3 and SAC store twin critics q1/q2.
    # (The previous `args.alg == 'td3' or 'sac'` test was always true.)
    net_names = ['pi', 'q'] if agent_cls is DDPG else ['pi', 'q1', 'q2']
    model_root = '../adv_models_3' if args.adv else '../models'
    for net_name in net_names:
        weights_path = '%s/%s_model/%s_%s_%d' % (
            model_root, args.alg, args.env, net_name, args.pfm)
        getattr(agent.ac, net_name).load_state_dict(torch.load(weights_path))

    video_path = '../videos/%s.mp4' % args.name
    os.makedirs(os.path.dirname(video_path), exist_ok=True)

    o = agent.env.reset()
    total_r = 0
    video = VideoRecorder(agent.env, video_path)
    try:
        for _ in range(args.steps):
            # NOTE(review): original note said "unset LD_PRELOAD" — presumably
            # frame capture needs LD_PRELOAD unset in the shell; confirm.
            video.capture_frame()
            a = agent.get_action(o)
            o2, r, d, _ = agent.env.step(a)
            total_r += r
            o = o2
            if d:
                o = agent.env.reset()
                print(total_r)
                total_r = 0
    finally:
        # The recorder must be closed explicitly so its encoder flushes the
        # mp4 to disk; otherwise the video file can be incomplete.
        video.close()
        agent.env.close()